#' ---
#' title: Reproduce Appendix Figure A1 (Twitter Sentiment Analysis Task, Varying Prompt and Model Variant)
#' date: 2024-10-02
#' version: 1.0
#' ---

library(tidyverse)
library(patchwork)
set.seed(42)

## Function to load the appropriate variant dataset and plot results

#' Build one panel of Appendix Figure A1 for a single prompt/model variant
#'
#' Reads `appendix-A-<variant>.csv` from the working directory (the variant
#' name is lower-cased and its spaces replaced by hyphens), sums the
#' probability assigned to Negative / Neutral / Positive tokens for each
#' tweet, scores each tweet on the first principal component of those three
#' columns, and plots that score against the mean of the three expert codes.
#'
#' @param variant Display name of the variant, e.g. `'Few-Shot Ada'`. It is
#'   used both to derive the input file name and as the panel title.
#' @return A ggplot object: a jittered scatter of expert score vs. GPT-3 score
#'   with a linear fit. The title shows `variant` and the correlation between
#'   the two scores (`cor()` with its default method, pairwise complete
#'   observations), rounded to three decimals.
subplot <- function(variant = 'Few-Shot Ada'){
  
  filename <- variant |> 
    str_to_lower() |> 
    str_replace_all(' ', '-')
  
  d <- read_csv(paste0('appendix-A-', filename, '.csv'))
  
  # create a "long" dataframe with each token and its predicted probability.
  # Both column families are pivoted in one step (via the `.value` sentinel)
  # so that tokenK is paired with probK by construction rather than by
  # relying on two separately pivoted frames sharing the same row order.
  tokens <- d |> 
    select(tweet_id, token1:token5, prob1:prob5) |> 
    pivot_longer(cols = -tweet_id,
                 names_to = c('.value', 'position'),
                 names_pattern = '^(token|prob)(\\d+)$') |> 
    select(tweet_id, sentiment = token, probability = prob)
  
  # remove whitespace, capitalize, and expand the abbreviated labels
  tokens <- tokens |> 
    mutate(sentiment = str_to_title(str_trim(sentiment)),
           sentiment = if_else(sentiment == 'Neg', 'Negative', sentiment),
           sentiment = if_else(sentiment == 'Pos', 'Positive', sentiment))
  
  # total probability per tweet for each of the three sentiment labels;
  # any other token is discarded, and a label that never appears among a
  # tweet's tokens gets probability 0
  gpt_3_sentiment <- tokens |> 
    filter(sentiment %in% c('Positive', 'Negative', 'Neutral')) |> 
    group_by(tweet_id, sentiment) |> 
    summarize(probability = sum(probability), .groups = 'drop') |> 
    pivot_wider(names_from = 'sentiment',
                values_from = 'probability', 
                values_fill = 0) |> 
    filter(!is.na(Negative))
  
  # compute GPT-3 sentiment score by taking the first component in a PCA
  p <- gpt_3_sentiment |> 
    select(Negative, Neutral, Positive) |> 
    princomp()
  
  # The component is negated so that larger scores mean "more positive".
  # NOTE(review): this presumes the first component loads non-negatively on
  # Negative (the first column), which princomp()'s default fix_sign = TRUE
  # is documented to enforce -- confirm on the R version in use.
  gpt_3_sentiment$gpt_3_score <- -1 * p$scores[,1]
  
  # merge with full dataset; tweets without any recognized sentiment token
  # have no row in gpt_3_sentiment and therefore get an NA score
  d <- left_join(d, gpt_3_sentiment, by = 'tweet_id')
  
  # compute expert score as the mean of the three coders
  d$expert_code <- (d$ornstein_code + d$blasingame_code + d$truscott_code) / 3
  
  correlation <- cor(d$gpt_3_score, d$expert_code, 
                     use = 'pairwise.complete.obs')
  
  # return plot.
  # theme_bw() is a complete theme and replaces every theme setting added
  # before it, so only tweaks placed after it take effect.
  ggplot(data = d,
         mapping = aes(x = expert_code, 
                       y = gpt_3_score)) +
    geom_jitter(alpha = 1/10, width = 0.1, height = 0) +
    geom_smooth(method = 'lm', 
                formula = y ~ x,
                se = FALSE, 
                color = 'gray20') +
    labs(x = 'Expert Positivity', 
         y = 'GPT-3 Positivity',
         title = paste0(variant, ' (\u03C1 = ', round(correlation, 3), ')')) +
    theme_bw() + 
    theme(plot.title = element_text(hjust = 0.5))
  
}

# Variants in panel order: prompt style varies slowest, model size fastest.
variants <- c(
  'Few-Shot Ada', 'Few-Shot Babbage', 'Few-Shot Curie', 'Few-Shot Davinci',
  'One-Shot Ada', 'One-Shot Babbage', 'One-Shot Curie', 'One-Shot Davinci',
  'Zero-Shot Ada', 'Zero-Shot Babbage', 'Zero-Shot Curie', 'Zero-Shot Davinci'
)

# Build one panel per variant and chain them left to right with patchwork's
# `+`, then collect the shared axes and axis titles across panels.
panels <- lapply(variants, subplot)

p <- Reduce(`+`, panels) +
  plot_layout(axes = 'collect', axis_titles = 'collect')

# Write the combined figure to disk.
ggsave(filename = 'figure-A1.png',
       plot = p,
       width = 18,
       height = 12)
